Remove or simplify hardcoded lists of device types #6235
@@ -71,6 +71,11 @@ class IfrtComputationClient : public ComputationClient {

  std::string GetDefaultDevice() const override;

  torch_xla::DeviceType GetDeviceType() const override {
    return torch_xla::DeviceType(
        absl::AsciiStrToUpper(client_->platform_name()));
The platform_name will be something like CUDA or TPU?
Yeah
LGTM. Thanks for the simplification!
@@ -72,7 +72,7 @@ def is_xla_tensor(tensor):


 def parse_xla_device(device):
-  m = re.match(r'(CPU|TPU|GPU|ROCM|CUDA|XPU|NEURON):(\d+)$', device)
+  m = re.match(r'([A-Z]+):(\d+)$', device)
Trying to think: where would it fail if the user has a typo in PJRT_DEVICE?
The device name here comes from PjRtClient::platform_name, which is not guaranteed to match the PJRT_DEVICE name.
If there's a typo in PJRT_DEVICE, the runtime will fail to initialize and throw an error.
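For reference, here is a minimal sketch of how the generalized pattern from the diff behaves. The regex line is taken from the diff; the return shape (a (type, ordinal) tuple on a match, None otherwise) is an assumption made for illustration, since the hunk shows only the re.match line.

```python
import re


def parse_xla_device(device):
  # Accept any upper-case device type name followed by an ordinal,
  # instead of a hardcoded list such as CPU|TPU|GPU|...
  m = re.match(r'([A-Z]+):(\d+)$', device)
  if m:
    # Assumed return shape: (device type name, ordinal).
    return (m.group(1), int(m.group(2)))
  return None


print(parse_xla_device('TPU:0'))   # a well-known device type
print(parse_xla_device('TEST:3'))  # an unknown plugin device type still parses
print(parse_xla_device('tpu:0'))   # lower-case names do not match
```

The pattern only checks the shape of the string. Whether the name corresponds to a real backend is decided earlier, when the runtime initializes.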
}  // namespace

std::string DeviceType::XlaDeviceTypeToString(XlaDeviceType hw_type) {
  XLA_CHECK(hw_type != XlaDeviceType::PLUGIN) << "PLUGIN type name unknown";
When will this function be called? Are we not expecting it to be called on a plugin?
This function is used when constructing a device from the XlaDeviceType enum type rather than the device name in a string. E.g. here:
xla/torch_xla/csrc/xla_backend_impl.cpp
Lines 183 to 184 in 9a4ef68
  default_device_type_ =
      std::make_shared<DeviceType>(static_cast<XlaDeviceType>(type));
If all we get is the placeholder XlaDeviceType::PLUGIN type, then we don't know the actual platform_name of the device type for toString. I think this is a relatively rare case in our code, so I can try to remove it. I have a more ambitious refactoring effort going in #6261.
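The constraint described above can be modeled with a short Python sketch. This is an illustrative toy, not the torch_xla C++ code: the class layout, the from_enum helper, and the trimmed enum member list are all hypothetical. It only shows why a bare PLUGIN enum value cannot be turned back into a name, while a name string can always be kept.

```python
import enum


class XlaDeviceType(enum.Enum):
  # Trimmed, illustrative member list (not the full enum from the PR).
  CPU = enum.auto()
  TPU = enum.auto()
  CUDA = enum.auto()
  PLUGIN = enum.auto()  # placeholder for device types not listed above


class DeviceType:
  """Toy model: keeps the enum value and the platform name together."""

  def __init__(self, type_name):
    self.type_name = type_name
    # Names that are not enum members fall back to the PLUGIN placeholder.
    self.hw_type = XlaDeviceType.__members__.get(
        type_name, XlaDeviceType.PLUGIN)

  @classmethod
  def from_enum(cls, hw_type):
    # Mirrors the XLA_CHECK in the diff: PLUGIN alone carries no name.
    if hw_type is XlaDeviceType.PLUGIN:
      raise ValueError('PLUGIN type name unknown')
    return cls(hw_type.name)

  def __str__(self):
    return self.type_name


print(DeviceType('TEST'), DeviceType('TEST').hw_type)
print(DeviceType.from_enum(XlaDeviceType.TPU))
```

Constructing from the name string always works, which is why the string form is the one that can safely cross the Python/C++ boundary in this model.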
torch_xla/csrc/device.h
Outdated
      type_name_(type_name) {}

  // TODO(wcromar): do we even want this default constructor?
  DeviceType() : DeviceType(XlaDeviceType::CPU){};
I feel like we should eventually let PyTorch/XLA auto-detect the device based on some rules (libtpu, CUDA version, etc.). I can't really think of a case where we'd want to default-construct a CPU device...
I'm 90% sure this constructor is never called. Logically, the "default" device type is managed by xla_backend_impl.
Actually, we definitely don't need this. Upstream provides a better default constructor: https://github.com/pytorch/pytorch/blob/16d69290c6d037a25e32220b9517597d04dbd0bf/torch/csrc/lazy/backend/backend_device.cpp#L14-L15
I'm just going to delete this now instead of leaving the TODO.
Allows new device plugins without registering the device type in our code! See #6242 for broader context of this cleanup.
- [...] PLUGIN device type for unknown devices
- [...] PjRtClient::platform_name) in the DeviceType, since we still need to pass the device name string across the Python/C++ boundary
- [...] devkind argument, since we only support one PJRT backend at a time
- [...] GPU device type

Tested with the CPU plugin in #6253, where I change the platform name to TEST to simulate an unknown device type. test_operations.py passes with that plugin.